/*
 * MIT License
 *
 * Copyright (c) 2020 wen.gu <454727014@qq.com>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

 /***************************************************************************
 * Name: ipc_socket_msvc.cpp
 *
 * Purpose: implementation of a socket wrapper for IPC over Unix domain sockets (AF_UNIX), MSVC/Windows build
 *
 * Developer:
 *   wen.gu , 2021-10-12
 *
 * TODO:
 *
 ***************************************************************************/

 /******************************************************************************
 **    INCLUDES
 ******************************************************************************/
#if defined(_MSC_VER)

#include "ipc_socket.h"

#include <WinSock2.h>
#include <afunix.h>

#include <Windows.h>

#include <atomic>
#include <chrono>
#include <map>
#include <string>
#include <thread>
#include <vector>

#pragma comment(lib,"wsock32.lib")

#define LOG_TAG "ipcs"
#include "icpp/core/log.h"

namespace icpp
{
namespace com
{
/******************************************************************************
 **    MACROS
 ******************************************************************************/
/** directory prefix prepended to every AF_UNIX socket path built by MakeUrlAddress();
 *  kept as a macro because it is concatenated with format-string literals */
#define SOCK_PATH_PREFIX "/tmp/"
/** backlog passed to listen() on the service socket.
 *  NOTE(review): Winsock's default FD_SETSIZE is 64, so fillFdSet() may not be able to
 *  watch this many clients at once - confirm the intended limit */
#define MAX_CLIENT_CONNECT_LIMIT 128

/** max number of bytes the service reads from one client per recv() call */
#define SOCK_RECV_SIZE 1500

/** seconds the service receive thread blocks in select() before re-checking is_running_ */
#define SELECT_TIME_OUT_SEC 1

/** timeout for one client connect attempt, in milliseconds (1 second) */
#define CONNECT_TIMEOUT_MS (1 * 1000)
/******************************************************************************
 **    VARIABLE DEFINITIONS
 ******************************************************************************/
/** native Winsock socket handle.
 *  NOTE(review): SOCKET is an unsigned type in Winsock, so validity must be checked
 *  against INVALID_SOCKET, never with "< 0" / "> 0" */
using socket_t = SOCKET;

/**
 * Result of SocketCreate(): the created socket plus the address it was bound to.
 *
 * All members carry defaults so that a plain `SocketInfo info;` is valid input for
 * SocketCreate(), which refuses to run unless socket_id starts as INVALID_SOCKET.
 * (Without the initialiser the id was indeterminate, so that guard fired on garbage.)
 */
struct SocketInfo
{
    socket_t socket_id = INVALID_SOCKET;  /**< socket handle, INVALID_SOCKET when none */
    int32_t name_len = 0;                 /**< length of sock_addr as passed to bind() */
    struct sockaddr_un sock_addr = {};    /**< AF_UNIX address the socket is bound to */
};

/**
 * Process-wide Winsock initialisation guard.
 *
 * get() constructs a function-local static on first use, which calls WSAStartup(2.2).
 * The destructor only calls WSACleanup() if that startup actually succeeded.
 */
class SocketInitializer
{
protected:
    SocketInitializer()
    {
        /** keep the result: it decides whether WSACleanup() must be balanced later */
        startup_result_ = WSAStartup(MAKEWORD(2, 2), &wsaData);
    }
public:
    ~SocketInitializer()
    {
        if (startup_result_ == 0)
        {
            WSACleanup();
        }
    }

    /** singleton: copying would unbalance WSAStartup()/WSACleanup() */
    SocketInitializer(const SocketInitializer&) = delete;
    SocketInitializer& operator=(const SocketInitializer&) = delete;

public:
    /** @return the process-wide instance, initialising Winsock on first call */
    static SocketInitializer& get()
    {
        static SocketInitializer si;
        return si;
    }

    /** @return true when WSAStartup() succeeded */
    bool isReady() const
    {
        return startup_result_ == 0;
    }

private:
    WSADATA wsaData;
    int startup_result_ = -1;  /**< WSAStartup() return code, 0 on success */
};

/**
 * Private state of IpcSocketService: a listening AF_UNIX socket plus the accepted
 * client sockets, all served by one receive thread that multiplexes them with select().
 */
class IpcSocketService::Impl
{
public:
    /** accepted client socket -> client address reported by accept() */
    using IpcClientInfoMap = std::map<socket_t, std::string>;
public:
    /** written by start/stop on the caller's thread and polled by the receive thread,
     *  so it must be atomic (a plain bool here is a data race) */
    std::atomic<bool> is_running_{false};
    bool is_started_ = false;
    std::thread* receive_tread_ = nullptr;   /**< owned; created in startReceive(), joined/deleted in stopReceive() */
    socket_t socket_id_ = INVALID_SOCKET;    /**< listening socket */
    std::string url_addr_;                   /**< service name, turned into a path by MakeUrlAddress() */
    OnAddClientHandler on_add_client_;
    OnDelClientHandler on_del_client_;
    /** NOTE(review): touched by the receive thread (add/del) and by write() on the caller's
     *  thread without a lock - confirm callers serialise send() with the receive thread */
    IpcClientInfoMap client_infos_;
public:
    core::IcppErrc start(OnAddClientHandler on_add_client, OnDelClientHandler on_del_client, OnReceiveHandler on_receive);
    core::IcppErrc stop();

    core::IcppErrc write(int client_id, const uint8_t* buf, uint32_t len);

protected:
    core::IcppErrc startReceive(OnReceiveHandler on_receive);
    core::IcppErrc stopReceive();
    void addClient(socket_t socket_id, const char* client_addr);
    void delClient(socket_t socket_id);

    socket_t fillFdSet(fd_set& set);

    void onClientConnection(fd_set& sets);
    void onDataReceive(fd_set& sets, OnReceiveHandler on_receive);
};

/**
 * Private state of IpcSocketClient: one AF_UNIX socket connected to the service and
 * a receive thread that reads from it (and optionally reconnects).
 */
class IpcSocketClient::Impl
{
public:
    /** written by start/stop on the caller's thread and polled by the receive thread,
     *  so it must be atomic (a plain bool here is a data race) */
    std::atomic<bool> is_running_{false};
    /** NOTE(review): also written by the receive thread and read by send() - confirm
     *  whether this needs to be atomic as well */
    bool is_connected_ = false;
    bool is_auto_connect_ = false;           /**< reconnect from the receive thread after a disconnect */
    std::thread* receive_tread_ = nullptr;   /**< owned; created in startReceive(), joined/deleted in stopReceive() */
    socket_t socket_id_ = INVALID_SOCKET;
    std::string url_addr_;                   /**< service name, turned into a path by MakeUrlAddress() */
    OnConnectHandler on_connect_;
public:
    core::IcppErrc start(OnConnectHandler on_connect, OnReceiveHandler on_receive);
    core::IcppErrc stop();

protected:
    core::IcppErrc startReceive(OnReceiveHandler on_receive);
    core::IcppErrc stopReceive();
    void onConnectStateChange(bool is_connected);
    void onReconnect(int32_t timeout_ms);
    core::IcppErrc doConnect(int32_t timeout_ms);
};


/******************************************************************************
 **    inner FUNCTION DEFINITIONS
 ******************************************************************************/

/**
 * Fill sock_addr with the AF_UNIX path of a service or client endpoint.
 *
 *   service: SOCK_PATH_PREFIX <addr>
 *   client : SOCK_PATH_PREFIX <addr>_client_<pid>_<YYYYMMDD>_<hhmmss>_<p_sock>
 *
 * @param sock_addr  [out] reset and filled (family + path)
 * @param addr       service name
 * @param is_service selects the service or the client naming scheme
 * @param p_sock     instance pointer mixed into client names to keep them unique
 * @return address length to hand to bind()/connect()
 */
static int32_t MakeUrlAddress(struct sockaddr_un& sock_addr, const std::string& addr, bool is_service, void* p_sock)
{
    memset(&sock_addr, 0, sizeof(sock_addr));
    sock_addr.sun_family = AF_UNIX;

    /** bound by the sun_path array itself: sizeof(struct sockaddr_un) is larger by
     *  offsetof(sun_path) bytes and let snprintf write past the end of the struct */
    const size_t path_cap = sizeof(sock_addr.sun_path);

    if (is_service)
    {
        snprintf(sock_addr.sun_path, path_cap, SOCK_PATH_PREFIX "%s", addr.c_str());
    }
    else
    {
        time_t timeTotal = time(NULL);
        struct tm tm_now;
        memset(&tm_now, 0, sizeof(tm_now));
        struct tm* ptm = localtime(&timeTotal);
        if (ptm != NULL)
        {
            tm_now = *ptm;
        }

        /**
         * Only '_' is used as separator: a '/' would name non-existent sub-directories
         * and ':' is not a valid Windows file-name character, both making bind() fail.
         * tm_year counts from 1900 and tm_mon from 0, hence the corrections.
         */
        snprintf(sock_addr.sun_path, path_cap, SOCK_PATH_PREFIX "%s_client_%d_%04d%02d%02d_%02d%02d%02d_%p",
            addr.c_str(), _getpid(),
            tm_now.tm_year + 1900, tm_now.tm_mon + 1, tm_now.tm_mday,
            tm_now.tm_hour, tm_now.tm_min, tm_now.tm_sec, p_sock);
    }

    return static_cast<int32_t>(offsetof(struct sockaddr_un, sun_path) + strlen(sock_addr.sun_path) + 1);
}


/**
 * Create an AF_UNIX stream socket and bind it to the path derived from url_addr.
 *
 * @param url_addr     service name, see MakeUrlAddress()
 * @param is_service   choose the service or the client naming scheme
 * @param sock_info    [in/out] socket_id must be INVALID_SOCKET on entry; on success it
 *                     receives the bound socket, its address and address length
 * @param instance_ptr mixed into client socket names
 * @return OK on success; Undefined if a socket is already present, creation or bind fails
 */
static core::IcppErrc SocketCreate(const std::string& url_addr, bool is_service, SocketInfo& sock_info, void* instance_ptr = nullptr)
{
    SocketInitializer::get();  /** ensure WSAStartup() ran before the first socket call */

    if (sock_info.socket_id != INVALID_SOCKET)
    {
        LOGE("serivce socket already exist, please destroy old one before create operation\n");
        return core::IcppErrc::Undefined;
    }

    /** SOCKET is unsigned: failure is INVALID_SOCKET, a "< 0" test can never detect it */
    sock_info.socket_id = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sock_info.socket_id == INVALID_SOCKET)
    {
        LOGE("create AF_UNIX socket failed\n");
        return core::IcppErrc::Undefined;
    }

    /** MakeUrlAddress() resets the whole structure and sets the family itself */
    struct sockaddr_un& sock_addr = sock_info.sock_addr;
    sock_info.name_len = MakeUrlAddress(sock_addr, url_addr, is_service, instance_ptr);
    printf("%s addr:%s, len:%d\n", is_service ? "service" : "client", sock_addr.sun_path, sock_info.name_len);

    if (bind(sock_info.socket_id, (struct sockaddr*)&sock_addr, sock_info.name_len) == SOCKET_ERROR)
    {
        LOGE("bind service socket:%s, failed: %ld\n", sock_addr.sun_path, WSAGetLastError());
        closesocket(sock_info.socket_id);
        sock_info.socket_id = INVALID_SOCKET;
        return core::IcppErrc::Undefined;
    }

    return core::IcppErrc::OK;
}

/**
 * Shut down and close a socket, then mark the caller's handle as INVALID_SOCKET.
 * Does nothing when the handle is already INVALID_SOCKET.
 *
 * @param sock_id [in/out] handle to close; reset to INVALID_SOCKET afterwards
 */
static void SocketClose(socket_t& sock_id)
{
    if (sock_id == INVALID_SOCKET)
    {
        return;  /** nothing open */
    }

    shutdown(sock_id, SD_BOTH);
    closesocket(sock_id);
    sock_id = INVALID_SOCKET;
}

/**
 * Send the whole buffer on a connected socket, in chunks of at most 1500 bytes.
 *
 * @param sock_id connected socket
 * @param buf     data to send (may be nullptr only when len == 0)
 * @param len     number of bytes to send
 * @return OK once every byte was sent; Undefined on a null buffer, a zero-byte send
 *         or any socket error other than an interrupted call
 */
static core::IcppErrc SocketSend(socket_t sock_id, const uint8_t* buf, uint32_t len)
{
    if ((buf == nullptr) && (len > 0))
    {
        return core::IcppErrc::Undefined;
    }

    while (len > 0)
    {
        const int sendSize = (len >= 1500) ? 1500 : static_cast<int>(len);

        const int retSize = ::send(sock_id, reinterpret_cast<const char*>(buf), sendSize, 0);
        if (retSize > 0)
        {
            len -= static_cast<uint32_t>(retSize);
            buf += retSize;
        }
        else if ((retSize == 0) || (WSAGetLastError() != WSAEINTR))
        {
            /** Winsock reports errors through WSAGetLastError(), not errno;
             *  only an interrupted call is retried */
            return core::IcppErrc::Undefined;
        }
    }

    return core::IcppErrc::OK;
}

/******************************************************************************
 **    FUNCTION DEFINITIONS
 ******************************************************************************/

/**
 * Create, bind and listen on the service socket, then start the receive thread.
 *
 * @return InvalidStatus if already started; the SocketCreate() error; Undefined if
 *         listen() fails; otherwise the result of startReceive()
 */
core::IcppErrc IpcSocketService::Impl::start(OnAddClientHandler on_add_client, OnDelClientHandler on_del_client, OnReceiveHandler on_receive)
{
    if (is_started_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    SocketInfo sock_info;
    sock_info.socket_id = INVALID_SOCKET;  /** SocketCreate() rejects an already-valid id */
    core::IcppErrc ret = SocketCreate(url_addr_, true, sock_info);
    if (core::IcppErrc::OK != ret)
    {
        return ret;
    }

    socket_id_ = sock_info.socket_id;
    if (listen(socket_id_, MAX_CLIENT_CONNECT_LIMIT) == SOCKET_ERROR)
    {
        LOGE("listen socket:%s, failed:%ld\n", sock_info.sock_addr.sun_path, WSAGetLastError());
        SocketClose(socket_id_);
        return core::IcppErrc::Undefined;
    }

    on_add_client_ = on_add_client;
    on_del_client_ = on_del_client;

    ret = startReceive(on_receive);
    if (core::IcppErrc::OK == ret)
    {
        /** without this flag stop() refused to run and the receive thread was never joined */
        is_started_ = true;
    }
    else
    {
        SocketClose(socket_id_);
    }

    return ret;
}

/**
 * Stop the receive thread, close every accepted client socket and the listening socket.
 *
 * @return InvalidStatus if not started, otherwise the result of stopReceive()
 */
core::IcppErrc IpcSocketService::Impl::stop()
{
    if (!is_started_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    const core::IcppErrc result = stopReceive();

    if (result == core::IcppErrc::OK)
    {
        /** the receive thread is gone, so the client table can be released safely */
        for (IpcClientInfoMap::const_iterator entry = client_infos_.begin(); entry != client_infos_.end(); ++entry)
        {
            socket_t client_sock = entry->first;
            SocketClose(client_sock);
        }
        client_infos_.clear();
    }

    SocketClose(socket_id_);
    is_started_ = false;

    return result;
}

/**
 * Send data to one client, or to every client when client_id is negative.
 * A client whose send fails is treated as disconnected and removed.
 *
 * @param client_id target client socket id, or < 0 to broadcast
 * @return the send result for a single client; always OK for a broadcast
 */
core::IcppErrc IpcSocketService::Impl::write(int client_id, const uint8_t* buf, uint32_t len)
{
    if (client_id >= 0)
    {
        const socket_t sock_id = static_cast<socket_t>(client_id);
        core::IcppErrc ret = SocketSend(sock_id, buf, len);
        if (ret != core::IcppErrc::OK)
        {/** client maybe disconnected */
            delClient(sock_id);
        }

        return ret;
    }

    /** broadcast: snapshot the ids first, delClient() erases from client_infos_ and
     *  erasing while range-iterating the map invalidates the loop iterator */
    std::vector<socket_t> targets;
    targets.reserve(client_infos_.size());
    for (const auto& it : client_infos_)
    {
        targets.push_back(it.first);
    }

    for (socket_t sock_id : targets)
    {
        if (SocketSend(sock_id, buf, len) != core::IcppErrc::OK)
        {/** client maybe disconnected */
            delClient(sock_id);
        }
    }

    /** todo refine me??  is need to proccess the error of send to client */
    return core::IcppErrc::OK;
}

/**
 * Start the service receive thread. It multiplexes the listening socket and all
 * client sockets with select(), accepting new clients and dispatching incoming data
 * to on_receive, until is_running_ is cleared.
 *
 * @return OK (also when the thread is already running)
 */
core::IcppErrc IpcSocketService::Impl::startReceive(OnReceiveHandler on_receive)
{
    if (is_running_)
    {
        return core::IcppErrc::OK;
    }

    is_running_ = true;

    receive_tread_ = new std::thread([this, on_receive]()
    {
        while (this->is_running_)
        {
            fd_set server_fd_set;
            socket_t max_fd = fillFdSet(server_fd_set);

            /** re-armed every pass: POSIX select() may overwrite the timeout */
            struct timeval timeout;
            timeout.tv_sec = SELECT_TIME_OUT_SEC;
            timeout.tv_usec = 0;

            /** Winsock ignores the first argument; kept for parity with POSIX select() */
            int32_t ret = select(static_cast<int>(max_fd + 1), &server_fd_set, NULL, NULL, &timeout);

            if (ret > 0)
            {
                onClientConnection(server_fd_set);
                onDataReceive(server_fd_set, on_receive);
            }
            else if (ret < 0)
            {
                LOGE("srv:%s, do select failed:%ld\n", url_addr_.c_str(), WSAGetLastError());
                /** back off so a persistent select() error neither spins the CPU nor floods the log */
                std::this_thread::sleep_for(std::chrono::milliseconds(100));
            }
            /** ret == 0: timeout, loop round and re-check is_running_ */
        }
    });

    return core::IcppErrc::OK;
}

/**
 * Ask the receive thread to finish, wait for it and free it.
 * The thread leaves its select() loop within SELECT_TIME_OUT_SEC.
 *
 * @return OK in every case
 */
core::IcppErrc IpcSocketService::Impl::stopReceive()
{
    if (!is_running_ || (receive_tread_ == nullptr))
    {
        return core::IcppErrc::OK;  /** no thread to stop */
    }

    is_running_ = false;
    receive_tread_->join();
    delete receive_tread_;
    receive_tread_ = nullptr;

    return core::IcppErrc::OK;
}



/**
 * Register a freshly accepted client: notify on_add_client_ (if set), then record
 * the client in client_infos_. Invalid sockets are ignored.
 */
void IpcSocketService::Impl::addClient(socket_t socket_id, const char* client_addr)
{
    if (socket_id == INVALID_SOCKET)
    {
        return;
    }

    if (on_add_client_)
    {
        on_add_client_(socket_id, client_addr);
    }

    client_infos_[socket_id] = client_addr;
}

/**
 * Forget a client: notify on_del_client_ (if set), drop it from client_infos_ and
 * close its socket. Unknown ids are ignored.
 */
void IpcSocketService::Impl::delClient(socket_t socket_id)
{
    IpcClientInfoMap::iterator it = client_infos_.find(socket_id);
    if (it == client_infos_.end())
    {
        return;
    }

    if (on_del_client_)
    {
        on_del_client_(socket_id);
    }

    client_infos_.erase(it);

    /** the socket was previously leaked on every disconnect; socket_id is a local copy */
    SocketClose(socket_id);
}

/**
 * Build the read set for select(): the listening socket plus every known client.
 *
 * @param set [out] cleared and refilled
 * @return the numerically largest socket placed in the set
 */
socket_t IpcSocketService::Impl::fillFdSet(fd_set& set)
{
    FD_ZERO(&set);

    /** the listening socket is always watched so new connections are noticed */
    FD_SET(socket_id_, &set);
    socket_t highest = socket_id_;

    for (IpcClientInfoMap::const_iterator entry = client_infos_.begin(); entry != client_infos_.end(); ++entry)
    {
        const socket_t client_sock = entry->first;
        FD_SET(client_sock, &set);
        highest = (client_sock > highest) ? client_sock : highest;
    }

    return highest;
}

/**
 * If select() flagged the listening socket, accept the pending connection and
 * register it via addClient(); accept failures are logged.
 */
void IpcSocketService::Impl::onClientConnection(fd_set& sets)
{
    if (!FD_ISSET(socket_id_, &sets))
    {
        return;
    }

    struct sockaddr_un client_addr;
    memset(&client_addr, 0, sizeof(client_addr));
    int len = sizeof(client_addr);
    socket_t client_sock_fd = accept(socket_id_, (struct sockaddr*)&client_addr, &len);

    /** SOCKET is unsigned: "> 0" was also true for INVALID_SOCKET, hiding failures */
    if (client_sock_fd == INVALID_SOCKET)
    {
        LOGE("srv:%s, accept client info failed:%ld\n", url_addr_.c_str(), WSAGetLastError());
        return;
    }

    /** a path filling the whole array would not be NUL-terminated */
    client_addr.sun_path[sizeof(client_addr.sun_path) - 1] = '\0';
    addClient(client_sock_fd, client_addr.sun_path);
}

/**
 * Read from every client socket select() flagged as readable and hand the data to
 * on_receive. A closed connection or a non-interrupt error removes the client.
 */
void IpcSocketService::Impl::onDataReceive(fd_set& sets, OnReceiveHandler on_receive)
{
    /** snapshot the ready sockets first: delClient() (and on_receive, via write())
     *  erase from client_infos_, which would invalidate a live range-for iterator */
    std::vector<socket_t> ready;
    for (const auto& it : client_infos_)
    {
        if (FD_ISSET(it.first, &sets))
        {
            ready.push_back(it.first);
        }
    }

    uint8_t recvBuf[SOCK_RECV_SIZE];
    for (socket_t sock_id : ready)
    {
        int32_t retSize = ::recv(sock_id, (char*)recvBuf, SOCK_RECV_SIZE, 0);
        if (retSize > 0)
        {
            if (on_receive)
            {
                on_receive(sock_id, recvBuf, retSize);
            }
        }
        else if (retSize == 0)
        {/** target client is disconnected from current service */
            delClient(sock_id);
        }
        else if (WSAGetLastError() != WSAEINTR)
        {/** error occur; Winsock reports it via WSAGetLastError(), not errno */
            delClient(sock_id);
        }
    }
}

///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
/**
 * Connect to the service and start the client receive thread.
 *
 * @return InvalidStatus if already connected; the doConnect() error; otherwise the
 *         startReceive() result (the socket is closed again if that fails)
 */
core::IcppErrc IpcSocketClient::Impl::start(OnConnectHandler on_connect, OnReceiveHandler on_receive)
{
    if (is_connected_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    on_connect_ = on_connect;

    core::IcppErrc result = doConnect(CONNECT_TIMEOUT_MS);
    if (result != core::IcppErrc::OK)
    {
        return result;
    }

    result = startReceive(on_receive);
    if (result != core::IcppErrc::OK)
    {
        SocketClose(socket_id_);
        is_connected_ = false;
    }

    return result;
}

/**
 * Stop the receive thread (if any) and close the connection.
 *
 * A client that lost its connection may still own an auto-reconnect receive thread;
 * that thread is joined as well, otherwise it would outlive this object.
 *
 * @return InvalidStatus when neither connected nor running; otherwise OK, or the
 *         stopReceive() error
 */
core::IcppErrc IpcSocketClient::Impl::stop()
{
    if (!is_connected_ && !is_running_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    core::IcppErrc ret = stopReceive();
    if (core::IcppErrc::InvalidStatus == ret)
    {
        ret = core::IcppErrc::OK;  /** no receive thread: nothing to join */
    }

    if (core::IcppErrc::OK == ret)
    {
        SocketClose(socket_id_);
        is_connected_ = false;
    }

    return ret;
}

/**
 * Start the client receive thread. While running it reads from the socket and calls
 * on_receive. On disconnect it closes the socket and reports the state change. While
 * disconnected it reconnects (if is_auto_connect_) or idles with a short sleep.
 *
 * @return InvalidStatus if already running, OK otherwise
 */
core::IcppErrc IpcSocketClient::Impl::startReceive(OnReceiveHandler on_receive)
{
    if (is_running_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    is_running_ = true;
    receive_tread_ = new std::thread([this, on_receive]{
        uint8_t receiveBuf[2048];
        while (this->is_running_)
        {
            if (this->socket_id_ != INVALID_SOCKET)
            {
                int retLen = ::recv(this->socket_id_, (char*)receiveBuf, static_cast<int>(sizeof(receiveBuf)), 0);

                if (retLen > 0)
                {
                    if (on_receive)
                    {
                        on_receive(receiveBuf, retLen);
                    }
                }
                else if ((retLen == 0) || (WSAGetLastError() != WSAEINTR))
                { /** disconnected between client and service (Winsock errors come from WSAGetLastError, not errno) */
                    closesocket(this->socket_id_);
                    this->socket_id_ = INVALID_SOCKET;
                    this->is_connected_ = false;
                    onConnectStateChange(this->is_connected_);
                }
            }
            else
            {
                if (this->is_auto_connect_)
                {
                    this->onReconnect(CONNECT_TIMEOUT_MS); /** blocks up to CONNECT_TIMEOUT_MS */
                }

                /** still disconnected: pause instead of spinning the CPU at 100% */
                if (!this->is_connected_ && this->is_running_)
                {
                    std::this_thread::sleep_for(std::chrono::milliseconds(200));
                }
            }
        }
    });

    return core::IcppErrc::OK;
}

/**
 * Ask the client receive thread to finish, wait for it and free it.
 *
 * The thread may be blocked inside recv() indefinitely, so the socket is shut down
 * first to wake it; otherwise join() could hang forever.
 *
 * @return InvalidStatus if not running, OK otherwise
 */
core::IcppErrc IpcSocketClient::Impl::stopReceive()
{
    if (!is_running_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    if (receive_tread_)
    {
        is_running_ = false;
        if (socket_id_ != INVALID_SOCKET)
        {
            shutdown(socket_id_, SD_BOTH);  /** unblocks a pending recv() */
        }
        receive_tread_->join();
        delete receive_tread_;
        receive_tread_ = nullptr;
    }

    return core::IcppErrc::OK;
}

/**
 * Report a connection state change to on_connect_ (if set).
 *
 * @param is_connected the new state; forwarded as given (the parameter was previously
 *                     ignored in favour of the member is_connected_)
 */
void IpcSocketClient::Impl::onConnectStateChange(bool is_connected)
{
    if (on_connect_)
    {
        on_connect_(socket_id_, is_connected);
    }
}

/**
 * Drop whatever is left of the previous connection and try to connect again.
 * The outcome is not returned; doConnect() sets is_connected_ on success.
 *
 * @param timeout_ms upper bound for the connect attempt
 */
void IpcSocketClient::Impl::onReconnect(int32_t timeout_ms)
{
    SocketClose(socket_id_);
    is_connected_ = false;

    static_cast<void>(doConnect(timeout_ms));
}

/**
 * Connect a new socket to the service, waiting at most timeout_ms.
 *
 * The client socket is left unbound and connected to the service path. The previous
 * version bound the service's own path here, which collides with the listening
 * service socket, so the client could never connect.
 *
 * On success socket_id_ / is_connected_ are updated and on_connect_ is notified.
 * On failure the new socket is closed and socket_id_ is left untouched.
 *
 * @return OK, Timeout when the attempt did not finish in time, Undefined on error
 */
core::IcppErrc IpcSocketClient::Impl::doConnect(int32_t timeout_ms)
{
    SocketInitializer::get();  /** ensure WSAStartup() ran */

    socket_t sock_id = socket(AF_UNIX, SOCK_STREAM, 0);
    if (sock_id == INVALID_SOCKET)
    {
        LOGE("create AF_UNIX socket failed: %ld\n", WSAGetLastError());
        return core::IcppErrc::Undefined;
    }

    struct sockaddr_un srv_addr;
    int32_t srv_len = MakeUrlAddress(srv_addr, url_addr_, true, nullptr);

    /** non-blocking connect so the attempt can be bounded by timeout_ms */
    u_long non_blocking = 1;
    ioctlsocket(sock_id, FIONBIO, &non_blocking);

    core::IcppErrc ret = core::IcppErrc::OK;
    if (::connect(sock_id, (struct sockaddr*)&srv_addr, srv_len) == SOCKET_ERROR)
    {
        int err = WSAGetLastError();
        if (err != WSAEWOULDBLOCK)
        {
            LOGE("connect to %s failed with error: %ld\n", srv_addr.sun_path, err);
            ret = core::IcppErrc::Undefined;
        }
        else
        {
            fd_set write_set;
            fd_set err_set;
            FD_ZERO(&write_set);
            FD_ZERO(&err_set);
            FD_SET(sock_id, &write_set);
            FD_SET(sock_id, &err_set);

            TIMEVAL tv = { 0 };
            tv.tv_sec = timeout_ms / 1000;
            tv.tv_usec = (timeout_ms % 1000) * 1000;

            /** writable => connected; present in the except set => the attempt failed */
            int32_t res = select(0, nullptr, &write_set, &err_set, &tv);
            if (res == 0)
            {
                ret = core::IcppErrc::Timeout;
            }
            else if ((res < 0) || FD_ISSET(sock_id, &err_set))
            {
                LOGE("connect to %s failed (select:%d, error:%ld)\n", srv_addr.sun_path, res, WSAGetLastError());
                ret = core::IcppErrc::Undefined;
            }
        }
    }

    if (core::IcppErrc::OK != ret)
    {
        closesocket(sock_id);  /** previously leaked on timeout/failure */
        return ret;
    }

    /** back to blocking mode: the receive thread relies on a blocking recv() */
    u_long blocking = 0;
    ioctlsocket(sock_id, FIONBIO, &blocking);

    socket_id_ = sock_id;
    is_connected_ = true;
    onConnectStateChange(true);
    return core::IcppErrc::OK;
}


///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

/**
 * @param addr service name; becomes the socket path SOCK_PATH_PREFIX + addr on start()
 */
IpcSocketService::IpcSocketService(const std::string& addr)
    :impl_(new Impl)
{
    impl_->url_addr_.assign(addr);
}

/** Stops the service; the InvalidStatus returned when it was never started is ignored. */
IpcSocketService::~IpcSocketService()
{
    static_cast<void>(impl_->stop());
}


/** Start listening; forwards to Impl::start (note the different argument order there). */
core::IcppErrc IpcSocketService::start(OnReceiveHandler on_receive, OnAddClientHandler on_add_client, OnDelClientHandler on_del_client)
{
    return (*impl_).start(on_add_client, on_del_client, on_receive);
}

/** Stop the service; forwards to Impl::stop. */
core::IcppErrc IpcSocketService::stop()
{
    return (*impl_).stop();
}

/** Send to one client, or to all clients when client_id < 0; forwards to Impl::write. */
core::IcppErrc IpcSocketService::send(int client_id, const uint8_t* buf, uint32_t len)
{
    return (*impl_).write(client_id, buf, len);
}

///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

/**
 * @param addr              service name to connect to
 * @param is_auto_reconect  reconnect from the receive thread after a disconnect
 */
IpcSocketClient::IpcSocketClient(const std::string& addr, bool is_auto_reconect)
    :impl_(new Impl)
{
    impl_->is_auto_connect_ = is_auto_reconect;
    impl_->url_addr_.assign(addr);
}

/** Stops the client; an InvalidStatus result (nothing to stop) is ignored. */
IpcSocketClient::~IpcSocketClient()
{
    static_cast<void>(impl_->stop());
}


/** Connect and start receiving; forwards to Impl::start. */
core::IcppErrc IpcSocketClient::start(OnConnectHandler on_connect, OnReceiveHandler on_receive)
{
    return (*impl_).start(on_connect, on_receive);
}

/** Disconnect and stop receiving; forwards to Impl::stop. */
core::IcppErrc IpcSocketClient::stop()
{
    return (*impl_).stop();
}

/**
 * Send data to the service.
 *
 * @return InvalidStatus when not connected, otherwise the SocketSend() result
 */
core::IcppErrc IpcSocketClient::send(const uint8_t* buf, uint32_t len)
{
    if (!impl_->is_connected_)
    {
        return core::IcppErrc::InvalidStatus;
    }

    return SocketSend(impl_->socket_id_, buf, len);
}

} /** namespace com */
} /** namespace icpp */

#endif /** defined(_MSC_VER) */

